import logging
import random
import re
import tempfile
from functools import partial
from typing import Optional

import click
from datasets import load_dataset
from PIL import Image
from pydantic import BaseModel, model_validator

from tensorrt_llm.bench.dataset.utils import (
    generate_multimodal_dataset,
    generate_text_dataset,
    get_norm_dist_lengths,
    write_dataset_to_file,
)


def validate_output_len_dist(ctx, param, value):
    """Validate and parse the --output-len-dist option (click callback).

    Args:
        ctx: Click context supplied by the callback protocol (unused).
        param: Click parameter supplied by the callback protocol (unused).
        value: Raw option string in the form ``<mean>,<stdev>``, or None when
            the option was not given.

    Returns:
        None when ``value`` is None, otherwise a ``(mean, stdev)`` tuple of ints.

    Raises:
        AssertionError: When ``value`` is not exactly two non-negative integers
            separated by a single comma.
    """
    if value is None:
        return value
    # fullmatch (not match) so trailing garbage such as "100,10abc" or
    # "100,10,5" is rejected instead of being silently truncated.
    m = re.fullmatch(r"(\d+),(\d+)", value)
    if m is None:
        raise AssertionError(
            "Incorrect specification for --output-len-dist. Correct format: "
            "--output-len-dist <output_len_mean>,<output_len_stdev>"
        )
    return int(m.group(1)), int(m.group(2))


class DatasetConfig(BaseModel):
    """Dataset configurations.

    Describes which HuggingFace dataset to load and which keys of each
    request (a dict-like dataset row) hold the prompt, input, images and
    golden output. Exactly one of ``prompt_key`` and ``prompt`` must be set;
    this is enforced by `check_prompt` after model construction.
    """

    # Name of the dataset on HuggingFace.
    name: str
    # Config name of the dataset, if it has one.
    config_name: Optional[str] = None
    # Split of the dataset. Typical values: train, validation, test.
    # Required field; None leaves the split unspecified.
    split: Optional[str]
    # The request key holding the input sentence.
    input_key: Optional[str] = None
    # The request key holding the image(s). Suffixed variants
    # "<image_key>_1" .. "<image_key>_7" are also looked up by `get_images`.
    image_key: Optional[str] = None
    # The request key holding the prompt of the input sentence.
    # Must not be set when `prompt` is set.
    prompt_key: Optional[str] = None
    # A fixed prompt sentence used for every request.
    # Must not be set when `prompt_key` is set.
    prompt: Optional[str] = None
    # The request key used to derive the output sequence length.
    # Required field; set to None if the dataset has no output key.
    output_key: Optional[str]

    @model_validator(mode="after")
    def check_prompt(self) -> "DatasetConfig":
        """Ensure exactly one of `prompt_key` and `prompt` is provided."""
        if self.prompt_key and self.prompt:
            raise AssertionError(
                "--dataset-prompt-key and --dataset-prompt cannot be set at the same time."
            )
        if (not self.prompt_key) and (not self.prompt):
            raise AssertionError("Either --dataset-prompt-key or --dataset-prompt must be set.")
        return self

    @property
    def query(self):
        """Positional arguments for HuggingFace `datasets.load_dataset()`.

        Returns:
            ``[name, config_name]`` when a config name is set, else ``[name]``.
        """
        if self.config_name:
            return [self.name, self.config_name]
        return [self.name]

    def get_prompt(self, req):
        """Get the prompt sentence for the given request.

        Returns ``req[prompt_key]`` when `prompt_key` is set, otherwise the
        fixed `prompt`.

        Raises:
            AssertionError: When `prompt_key` is set but missing from ``req``.
        """
        if not self.prompt_key:
            return self.prompt
        # Explicit raise instead of `assert` so the check survives `python -O`.
        if self.prompt_key not in req:
            raise AssertionError(
                f"Dataset {self.name} does not have key '{self.prompt_key}'. "
                "Please set --dataset-prompt-key to one of the available keys: "
                f"{req.keys()}"
            )
        return req[self.prompt_key]

    def get_input(self, req):
        """Get the input sentence from the given request.

        Raises:
            AssertionError: When `input_key` is unset or missing from ``req``.
        """
        if self.input_key is None:
            raise AssertionError(
                "--dataset-input-key is not set. "
                "Please set --dataset-input-key to one of the available keys: "
                f"{req.keys()}"
            )
        if self.input_key not in req:
            raise AssertionError(
                f"Dataset {self.name} does not have key '{self.input_key}'. "
                "Please set --dataset-input-key to one of the available keys: "
                f"{req.keys()}"
            )
        return req[self.input_key]

    def get_images(self, req):
        """Get the images from the given request.

        Looks up `image_key` plus the suffixed keys ``<image_key>_1`` through
        ``<image_key>_7`` and returns every value that is present and not None,
        in that key order.

        Raises:
            AssertionError: When none of the candidate keys exist in ``req``.
        """
        image_keys = [self.image_key] + [f"{self.image_key}_{i}" for i in range(1, 8)]
        if not any(key in req for key in image_keys):
            raise AssertionError(
                f"Dataset {self.name} does not have key '{self.image_key}'. "
                "Please set --dataset-image-key to one of the available keys: "
                f"{req.keys()}"
            )
        return [req[key] for key in image_keys if key in req and req[key] is not None]

    def get_output(self, req):
        """Get the golden output sentence from the given request.

        Raises:
            RuntimeError: When `output_key` is not configured.
            AssertionError: When `output_key` is missing from ``req``.
        """
        if self.output_key is None:
            raise RuntimeError(
                "--dataset-output-key is not set. Please either:\n"
                "1. Define output length through --output-len-dist.\n"
                f"2. If the dataset {self.name} has key for golden output and "
                "you wish to set output length to the length of the golden "
                "output, set --dataset-output-key."
            )
        if self.output_key not in req:
            raise AssertionError(
                f"Dataset {self.name} does not have key '{self.output_key}'. "
                "Please set --dataset-output-key to one of the available keys: "
                f"{req.keys()}"
            )
        return req[self.output_key]


def load_dataset_from_hf(dataset_config: DatasetConfig):
    """Load dataset from HuggingFace in streaming mode.

    Args:
        dataset_config: A `DatasetConfig` object that defines the dataset to load.

    Returns:
        Dataset iterator.

    Raises:
        ValueError: When dataset loading fails due to incorrect dataset config
            setting. The original error is chained as ``__cause__`` and, when
            the failure mentions the config or the split, a hint is appended.
    """
    try:
        dataset = iter(
            load_dataset(
                *dataset_config.query,
                split=dataset_config.split,
                streaming=True,
                # NOTE(security): this allows dataset-provided loading scripts
                # to execute locally; only use with datasets you trust.
                trust_remote_code=True,
            )
        )
    except ValueError as e:
        # Work on the message text: an exception object supports neither
        # `in` nor `+=`, so operating on `e` directly raises TypeError and
        # hides the real failure.
        message = str(e)
        if "Config" in message:
            message += "\n Please add the config name to the dataset config yaml."
        elif "split" in message:
            message += "\n Please specify supported split in the dataset config yaml."
        raise ValueError(message) from e

    return dataset


def _collect_image_paths(images):
    """Convert dataset image entries into a list of file paths.

    String entries are taken as paths as-is; PIL images are written to a
    persistent temporary JPEG file; None entries are skipped.

    Raises:
        ValueError: When an entry is neither a string nor a PIL image.
    """
    image_paths = []
    for image in images:
        if image is None:
            continue
        if isinstance(image, str):
            image_paths.append(image)
        elif isinstance(image, Image.Image):
            # delete=False: only the path is recorded in the dataset, so the
            # file has to outlive this context manager.
            with tempfile.NamedTemporaryFile(suffix=".jpg", delete=False) as tmp_file:
                logging.debug(f"Saving image to {tmp_file.name}")
                image.convert("RGB").save(tmp_file, "JPEG")
                image_paths.append(tmp_file.name)
        else:
            raise ValueError(f"Invalid image path: {image}")
    return image_paths


@click.command(name="real-dataset")
@click.option("--dataset-name", required=True, type=str, help="Dataset name in HuggingFace.")
@click.option(
    "--dataset-config-name",
    type=str,
    default=None,
    help="Dataset config name in HuggingFace (if exists).",
)
@click.option("--dataset-split", type=str, required=True, help="Split of the dataset to use.")
@click.option("--dataset-input-key", type=str, help="The dataset dictionary key for input.")
@click.option(
    "--dataset-image-key", type=str, default="image", help="The dataset dictionary key for images."
)
@click.option(
    "--dataset-prompt-key",
    type=str,
    default=None,
    help="The dataset dictionary key for prompt (if exists).",
)
@click.option(
    "--dataset-prompt",
    type=str,
    default=None,
    help="The prompt string when there is no prompt key for the dataset.",
)
@click.option(
    "--dataset-output-key",
    type=str,
    default=None,
    help="The dataset dictionary key for output (if exists).",
)
@click.option(
    "--num-requests",
    type=int,
    default=None,
    help="Number of requests to be generated. Will be capped to min(dataset.num_rows, num_requests).",
)
@click.option(
    "--max-input-len",
    type=int,
    default=None,
    help="Maximum input sequence length for a given request. This will be used to filter out the "
    "requests with long input sequence length. Default will include all the requests.",
)
@click.option(
    "--output-len-dist",
    type=str,
    default=None,
    callback=validate_output_len_dist,
    help="Output length distribution. Default will be the length of the golden output from "
    "the dataset. Format: <output_len_mean>,<output_len_stdev>. E.g. 100,10 will randomize "
    "the output length with mean=100 and stdev=10.",
)
@click.pass_obj
def real_dataset(root_args, **kwargs):
    """Prepare dataset from real dataset."""
    # NOTE: the docstring above is the click help text, so details live here.
    #
    # root_args: object passed by click.pass_obj; this function reads its
    #   `tokenizer`, `task_id`, `rand_task_id`, `random_seed` and `output`.
    # kwargs: the parsed CLI options. Every `dataset_*` option is forwarded to
    #   `DatasetConfig` with the prefix stripped.
    #
    # Each request is handled as multimodal when it carries an "image",
    # "image_1" or "video" key, otherwise as plain text. The collected
    # requests are finally written out through `write_dataset_to_file`.
    # NOTE(review): detection uses these fixed key names, not
    # --dataset-image-key; confirm that is intended for custom image keys.
    prefix = "dataset_"
    dataset_config = DatasetConfig(
        **{k[len(prefix) :]: v for k, v in kwargs.items() if k.startswith(prefix)}
    )
    num_requests = kwargs["num_requests"]
    max_input_len = kwargs["max_input_len"]
    output_len_dist = kwargs["output_len_dist"]

    input_ids = []
    input_lens = []
    output_lens = []
    # NOTE(review): task_ids is filled below but not consumed anywhere in this
    # function; kept so the accesses on root_args stay unchanged.
    task_ids = []
    req_cnt = 0
    modality = None
    multimodal_texts = []
    multimodal_image_paths = []
    for req in load_dataset_from_hf(dataset_config):
        if any(key in req for key in ("image", "image_1", "video")):
            # multimodal input
            if "video" in req and req["video"] is not None:
                # Previously `assert "Not supported yet"`, which never fires.
                raise NotImplementedError("Video inputs are not supported yet.")
            # Explicit raise instead of `assert` so the check survives `python -O`.
            if output_len_dist is None:
                raise AssertionError(
                    "Output length distribution must be set for multimodal requests."
                )
            modality = "image"
            text = dataset_config.get_prompt(req)
            image_paths = _collect_image_paths(dataset_config.get_images(req))
            multimodal_texts.append(text)
            multimodal_image_paths.append(image_paths)
        else:
            # text input
            prompt = dataset_config.get_prompt(req) + " " + dataset_config.get_input(req)
            logging.debug(f"Input sequence: {prompt}")
            line = root_args.tokenizer.encode(prompt)
            # Over-long inputs are dropped and do not count towards num_requests.
            if max_input_len and len(line) > max_input_len:
                continue
            input_ids.append(line)
            input_lens.append(len(line))

            # output length taken from the golden output
            if output_len_dist is None:
                output_lens.append(len(root_args.tokenizer.encode(dataset_config.get_output(req))))

        # lora task id
        task_id = root_args.task_id
        if root_args.rand_task_id is not None:
            min_id, max_id = root_args.rand_task_id
            task_id = random.randint(min_id, max_id)
        task_ids.append(task_id)

        req_cnt += 1
        if num_requests and req_cnt >= num_requests:
            break

    num_collected = len(input_ids) if modality is None else len(multimodal_texts)
    if num_requests and num_collected < num_requests:
        logging.warning(
            f"Number of requests={num_collected} is"
            f" smaller than the num-requests user set={num_requests}."
        )

    # output length randomized from the requested distribution
    if output_len_dist is not None:
        osl_mean, osl_stdev = output_len_dist
        output_lens = get_norm_dist_lengths(
            osl_mean,
            osl_stdev,
            num_collected,
            root_args.random_seed,
        )
    logging.debug(f"Input lengths: {input_lens}")
    logging.debug(f"Output lengths: {output_lens}")
    if modality is not None:
        logging.debug(f"Modality: {modality}")

    if modality is not None:
        dataset_generator = partial(
            generate_multimodal_dataset, multimodal_texts, multimodal_image_paths
        )
    else:
        dataset_generator = partial(generate_text_dataset, input_ids)
    write_dataset_to_file(dataset_generator(output_lens), root_args.output)
